{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Please run this notebook locally (if you have all the dependencies and a GPU).\n",
    "Technically, you can run this notebook on Google Colab, but you need to set up a microphone for Colab.\n",
    " \n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "5. Set up microphone for Colab\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg portaudio19-dev\n",
    "!pip install unidecode\n",
    "!pip install pyaudio\n",
    "\n",
    "## Install NeMo\n",
    "BRANCH = 'main'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n",
    "\n",
    "## Install TorchAudio\n",
    "!pip install 'torchaudio>=0.10.0' -f https://download.pytorch.org/whl/torch_stable.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates offline and online (streaming from a microphone) speech command recognition with NeMo."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook requires the PyAudio library to capture a signal from an audio input device.\n",
    "On Ubuntu, run the following commands to install it:\n",
    "```\n",
    "sudo apt-get install -y portaudio19-dev\n",
    "pip install pyaudio\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook requires the `torchaudio` library to be installed for MatchboxNet. Please follow the instructions on the [torchaudio GitHub page](https://github.com/pytorch/audio#installation) to install the appropriate version of torchaudio.\n",
    "\n",
    "If you would like to install the latest version, please run the following command to install it:\n",
    "\n",
    "```\n",
    "conda install -c pytorch torchaudio\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pyaudio as pa\n",
    "import os, time\n",
    "import librosa\n",
    "import IPython.display as ipd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import nemo\n",
    "import nemo.collections.asr as nemo_asr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample rate, Hz\n",
    "SAMPLE_RATE = 16000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Restore the model from NGC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "mbn_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(\"commandrecognition_en_matchboxnet3x1x64_v2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the speech command model MatchboxNet does not handle the non-speech scenario, \n",
    "we use a Voice Activity Detection (VAD) model to reduce false alarms on background noise and silence. The speech command prediction is reported only when speech activity is detected. \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Please note that the VAD model may not work well with every microphone; you might need to fine-tune it on your own input and experiment with different parameters.**"
   ]
  },
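  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gating idea can be sketched without any model: a command prediction is reported only when the speech probability from the VAD reaches a threshold. The function name `gate_prediction`, the default threshold of 0.8 and the sample values are illustrative assumptions, not part of NeMo.\n",
    "\n",
    "```python\n",
    "def gate_prediction(speech_prob, command_label, threshold=0.8):\n",
    "    # Report the command only if the VAD is confident that speech is present.\n",
    "    if speech_prob >= threshold:\n",
    "        return command_label\n",
    "    return 'no-speech'\n",
    "\n",
    "\n",
    "# Hypothetical (speech probability, predicted command) pairs for three frames.\n",
    "frames = [(0.05, 'yes'), (0.93, 'yes'), (0.40, 'no')]\n",
    "results = [gate_prediction(prob, label) for prob, label in frames]\n",
    "print(results)\n",
    "```\n",
    "\n",
    "A higher threshold suppresses more false alarms on noise, at the cost of missing quiet or short commands."
   ]
  },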
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained('vad_marblenet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Observing the config of the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Preserve a copy of the full config\n",
    "vad_cfg = copy.deepcopy(vad_model._cfg)\n",
    "mbn_cfg = copy.deepcopy(mbn_model._cfg)\n",
    "print(OmegaConf.to_yaml(mbn_cfg))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What classes can this model recognize?\n",
    "\n",
    "Before we begin inference on the actual audio stream, let's look at the classes this model was trained to recognize.\n",
    "\n",
    "**The MatchboxNet model is not designed to recognize out-of-vocabulary (OOV) words.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = mbn_cfg.labels\n",
    "for i in range(len(labels)):\n",
    "    print('%-10s' % (labels[i]), end=' ')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the models to inference mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set model to inference mode\n",
    "mbn_model.eval();\n",
    "vad_model.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up data for Streaming Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.core.classes import IterableDataset\n",
    "from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType\n",
    "import torch\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple data layer to pass audio signal\n",
    "class AudioDataLayer(IterableDataset):\n",
    "    @property\n",
    "    def output_types(self):\n",
    "        return {\n",
    "            'audio_signal': NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),\n",
    "            'a_sig_length': NeuralType(tuple('B'), LengthsType()),\n",
    "        }\n",
    "\n",
    "    def __init__(self, sample_rate):\n",
    "        super().__init__()\n",
    "        self._sample_rate = sample_rate\n",
    "        self.output = True\n",
    "        \n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        if not self.output:\n",
    "            raise StopIteration\n",
    "        self.output = False\n",
    "        return torch.as_tensor(self.signal, dtype=torch.float32), \\\n",
    "               torch.as_tensor(self.signal_shape, dtype=torch.int64)\n",
    "        \n",
    "    def set_signal(self, signal):\n",
    "        self.signal = signal.astype(np.float32)/32768.\n",
    "        self.signal_shape = self.signal.size\n",
    "        self.output = True\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1"
   ]
  },
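  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`set_signal` expects 16-bit PCM samples and rescales them to floating point values in the range [-1, 1). A minimal NumPy illustration of that conversion, using made-up sample values:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up 16-bit PCM samples covering the full int16 range.\n",
    "pcm = np.array([-32768, -16384, 0, 16384, 32767], dtype=np.int16)\n",
    "\n",
    "# Same scaling as AudioDataLayer.set_signal.\n",
    "normalized = pcm.astype(np.float32) / 32768.0\n",
    "print(normalized.min(), normalized.max())\n",
    "```\n",
    "\n",
    "Passing audio that is already normalized to floating point would be scaled down a second time, so raw `int16` samples should be fed to the data layer."
   ]
  },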
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_layer = AudioDataLayer(sample_rate=mbn_cfg.train_ds.sample_rate)\n",
    "data_loader = DataLoader(data_layer, batch_size=1, collate_fn=data_layer.collate_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference method for an audio signal (single instance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer_signal(model, signal):\n",
    "    data_layer.set_signal(signal)\n",
    "    batch = next(iter(data_loader))\n",
    "    audio_signal, audio_signal_len = batch\n",
    "    audio_signal, audio_signal_len = audio_signal.to(model.device), audio_signal_len.to(model.device)\n",
    "    logits = model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We do not apply any post-processing techniques to the predictions here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class for streaming frame-based ASR\n",
    "# 1) use reset() method to reset FrameASR's state\n",
    "# 2) call transcribe(frame) to do ASR on\n",
    "#    contiguous signal's frames\n",
    "class FrameASR:\n",
    "    \n",
    "    def __init__(self, model_definition,\n",
    "                 frame_len=2, frame_overlap=2.5, \n",
    "                 offset=0):\n",
    "        '''\n",
    "        Args:\n",
    "          frame_len (seconds): Frame's duration\n",
    "          frame_overlap (seconds): Duration of overlaps before and after current frame.\n",
    "          offset: Number of symbols to drop for smooth streaming.\n",
    "        '''\n",
    "        self.task = model_definition['task']\n",
    "        self.vocab = list(model_definition['labels'])\n",
    "        \n",
    "        self.sr = model_definition['sample_rate']\n",
    "        self.frame_len = frame_len\n",
    "        self.n_frame_len = int(frame_len * self.sr)\n",
    "        self.frame_overlap = frame_overlap\n",
    "        self.n_frame_overlap = int(frame_overlap * self.sr)\n",
    "        timestep_duration = model_definition['AudioToMFCCPreprocessor']['window_stride']\n",
    "        for block in model_definition['JasperEncoder']['jasper']:\n",
    "            timestep_duration *= block['stride'][0] ** block['repeat']\n",
    "        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len,\n",
    "                               dtype=np.float32)\n",
    "        self.offset = offset\n",
    "        self.reset()\n",
    "        \n",
    "    @torch.no_grad()\n",
    "    def _decode(self, frame, offset=0):\n",
    "        assert len(frame)==self.n_frame_len\n",
    "        self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]\n",
    "        self.buffer[-self.n_frame_len:] = frame\n",
    "\n",
    "        if self.task == 'mbn':\n",
    "            logits = infer_signal(mbn_model, self.buffer).to('cpu').numpy()[0]\n",
    "            decoded = self._mbn_greedy_decoder(logits, self.vocab)\n",
    "            \n",
    "        elif self.task == 'vad':\n",
    "            logits = infer_signal(vad_model, self.buffer).to('cpu').numpy()[0]\n",
    "            decoded = self._vad_greedy_decoder(logits, self.vocab)\n",
    "           \n",
    "        else:\n",
    "            raise ValueError(\"Task must be either 'mbn' or 'vad'.\")\n",
    "            \n",
    "        return decoded[:len(decoded)-offset]\n",
    "    \n",
    "    def transcribe(self, frame=None, merge=False):\n",
    "        if frame is None:\n",
    "            frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)\n",
    "        if len(frame) < self.n_frame_len:\n",
    "            frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')\n",
    "        unmerged = self._decode(frame, self.offset)\n",
    "        return unmerged\n",
    "        \n",
    "    \n",
    "    def reset(self):\n",
    "        '''\n",
    "        Reset frame_history and decoder's state\n",
    "        '''\n",
    "        self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)\n",
    "        self.mbn_s = []\n",
    "        self.vad_s = []\n",
    "        \n",
    "    @staticmethod\n",
    "    def _mbn_greedy_decoder(logits, vocab):\n",
    "        mbn_s = []\n",
    "        if logits.shape[0]:\n",
    "            class_idx = np.argmax(logits)\n",
    "            class_label = vocab[class_idx]\n",
    "            mbn_s.append(class_label)         \n",
    "        return mbn_s\n",
    "    \n",
    "    \n",
    "    @staticmethod\n",
    "    def _vad_greedy_decoder(logits, vocab):\n",
    "        vad_s = []\n",
    "        if logits.shape[0]:\n",
    "            probs = torch.softmax(torch.as_tensor(logits), dim=-1)\n",
    "            probas, preds = torch.max(probs, dim=-1)\n",
    "            vad_s = [preds.item(), str(vocab[preds]), probs[0].item(), probs[1].item(), str(logits)]\n",
    "        return vad_s\n"
   ]
  },
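  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both decoders above are greedy: they pick the class with the highest score for the current buffer. A self-contained NumPy sketch of that step, with a hypothetical three-word label list and made-up logits:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "labels = ['yes', 'no', 'up']  # hypothetical label list\n",
    "logits = np.array([0.2, 2.5, -1.0], dtype=np.float32)  # made-up scores\n",
    "\n",
    "# Softmax turns logits into probabilities; subtracting the max keeps exp() stable.\n",
    "exp = np.exp(logits - logits.max())\n",
    "probs = exp / exp.sum()\n",
    "\n",
    "# Greedy decoding picks the most probable class.\n",
    "best = int(np.argmax(probs))\n",
    "print(labels[best], float(probs[best]))\n",
    "```\n",
    "\n",
    "Because softmax preserves the ordering of the logits, taking the argmax of the raw logits gives the same class; the probabilities are only needed when a threshold is applied, as in the VAD case."
   ]
  },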
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Streaming Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Offline inference\n",
    "Here we show an example of offline streaming inference. You can use your own file or download the provided demo audio file.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Streaming inference depends on a few factors, such as the frame length (STEP) and the buffer size (WINDOW_SIZE). Experiment with a few values in the cells below to see their effects."
   ]
  },
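  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the relationship between these two values concrete, the sketch below reproduces the buffer arithmetic used by `FrameASR` with plain NumPy: each new frame of `STEP` seconds is appended to a rolling buffer of roughly `WINDOW_SIZE` seconds. The helper name `buffer_layout` is illustrative and not part of NeMo.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "SAMPLE_RATE = 16000  # Hz, as used throughout this notebook\n",
    "\n",
    "\n",
    "def buffer_layout(step, window_size, sample_rate=SAMPLE_RATE):\n",
    "    # Mirrors FrameASR.__init__: the frame is `step` seconds long and the\n",
    "    # rest of the window is split into two equal overlaps around it.\n",
    "    n_frame_len = int(step * sample_rate)\n",
    "    n_frame_overlap = int((window_size - step) / 2 * sample_rate)\n",
    "    n_buffer = 2 * n_frame_overlap + n_frame_len\n",
    "    return n_frame_len, n_frame_overlap, n_buffer\n",
    "\n",
    "\n",
    "n_frame, n_overlap, n_buffer = buffer_layout(step=0.25, window_size=1.28)\n",
    "print(n_frame, n_overlap, n_buffer)\n",
    "\n",
    "# Rolling update, as in FrameASR._decode: drop the oldest frame, append the new one.\n",
    "buffer = np.zeros(n_buffer, dtype=np.float32)\n",
    "frame = np.ones(n_frame, dtype=np.float32)\n",
    "buffer[:-n_frame] = buffer[n_frame:]\n",
    "buffer[-n_frame:] = frame\n",
    "```\n",
    "\n",
    "With a 16 kHz signal, a `STEP` of 0.25 s corresponds to 4000 new samples per inference call, while the network always sees approximately `WINDOW_SIZE` seconds of audio."
   ]
  },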
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "STEP = 0.25\n",
    "WINDOW_SIZE = 1.28 # input segment length for NN we used for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wave\n",
    "\n",
    "def offline_inference(wave_file, STEP = 0.25, WINDOW_SIZE = 0.31):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        wave_file: path to the wave file to run inference on.\n",
    "        STEP: run inference every STEP seconds.\n",
    "        WINDOW_SIZE: length (in seconds) of the audio sent to the NN.\n",
    "    \"\"\"\n",
    "    \n",
    "    FRAME_LEN = STEP \n",
    "    CHANNELS = 1 # number of audio channels (expect mono signal)\n",
    "    RATE = SAMPLE_RATE # sample rate, 16000 Hz\n",
    "   \n",
    "    CHUNK_SIZE = int(FRAME_LEN * SAMPLE_RATE)\n",
    "    \n",
    "    mbn = FrameASR(model_definition = {\n",
    "                       'task': 'mbn',\n",
    "                       'sample_rate': SAMPLE_RATE,\n",
    "                       'AudioToMFCCPreprocessor': mbn_cfg.preprocessor,\n",
    "                       'JasperEncoder': mbn_cfg.encoder,\n",
    "                       'labels': mbn_cfg.labels\n",
    "                   },\n",
    "                   frame_len=FRAME_LEN, frame_overlap = (WINDOW_SIZE - FRAME_LEN)/2,\n",
    "                   offset=0)\n",
    "\n",
    "    wf = wave.open(wave_file, 'rb')\n",
    "    data = wf.readframes(CHUNK_SIZE)\n",
    "\n",
    "    while len(data) > 0:\n",
    "        signal = np.frombuffer(data, dtype=np.int16)\n",
    "        mbn_result = mbn.transcribe(signal)\n",
    "\n",
    "        if len(mbn_result):\n",
    "            print(mbn_result)\n",
    "\n",
    "        # Read the next chunk only after the current one has been processed,\n",
    "        # so that the first chunk of the file is not skipped.\n",
    "        data = wf.readframes(CHUNK_SIZE)\n",
    "\n",
    "    wf.close()\n",
    "    mbn.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "demo_wave = 'SpeechCommands_demo.wav'\n",
    "if not os.path.exists(demo_wave):\n",
    "    !wget \"https://dldata-public.s3.us-east-2.amazonaws.com/SpeechCommands_demo.wav\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "wave_file = demo_wave\n",
    "\n",
    "CHANNELS = 1\n",
    "audio, sample_rate = librosa.load(wave_file, sr=SAMPLE_RATE)\n",
    "dur = librosa.get_duration(y=audio, sr=sample_rate)\n",
    "print(dur)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ipd.Audio(audio, rate=sample_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Ground-truth is Yes No\n",
    "offline_inference(wave_file, STEP, WINDOW_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Online inference through microphone"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please note that the MatchboxNet and VAD models may not work well with every microphone; you might need to fine-tune them on your own input and experiment with different parameters. \\\n",
    "**We also recommend using headphones.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vad_threshold = 0.8 \n",
    "\n",
    "STEP = 0.1 \n",
    "WINDOW_SIZE = 0.15\n",
    "mbn_WINDOW_SIZE = 1\n",
    "\n",
    "CHANNELS = 1 \n",
    "RATE = SAMPLE_RATE\n",
    "FRAME_LEN = STEP # use step of vad inference as frame len\n",
    "\n",
    "CHUNK_SIZE = int(STEP * RATE)\n",
    "vad = FrameASR(model_definition = {\n",
    "                   'task': 'vad',\n",
    "                   'sample_rate': SAMPLE_RATE,\n",
    "                   'AudioToMFCCPreprocessor': vad_cfg.preprocessor,\n",
    "                   'JasperEncoder': vad_cfg.encoder,\n",
    "                   'labels': vad_cfg.labels\n",
    "               },\n",
    "               frame_len=FRAME_LEN, frame_overlap=(WINDOW_SIZE - FRAME_LEN) / 2, \n",
    "               offset=0)\n",
    "\n",
    "mbn = FrameASR(model_definition = {\n",
    "                       'task': 'mbn',\n",
    "                       'sample_rate': SAMPLE_RATE,\n",
    "                       'AudioToMFCCPreprocessor': mbn_cfg.preprocessor,\n",
    "                       'JasperEncoder': mbn_cfg.encoder,\n",
    "                       'labels': mbn_cfg.labels\n",
    "                   },\n",
    "                   frame_len=FRAME_LEN, frame_overlap = (mbn_WINDOW_SIZE-FRAME_LEN)/2,\n",
    "                   offset=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vad.reset()\n",
    "mbn.reset()\n",
    "\n",
    "# Setup input device\n",
    "p = pa.PyAudio()\n",
    "print('Available audio input devices:')\n",
    "input_devices = []\n",
    "for i in range(p.get_device_count()):\n",
    "    dev = p.get_device_info_by_index(i)\n",
    "    if dev.get('maxInputChannels'):\n",
    "        input_devices.append(i)\n",
    "        print(i, dev.get('name'))\n",
    "\n",
    "if len(input_devices):\n",
    "    dev_idx = -2\n",
    "    while dev_idx not in input_devices:\n",
    "        print('Please type input device ID:')\n",
    "        dev_idx = int(input())\n",
    "\n",
    "    \n",
    "    def callback(in_data, frame_count, time_info, status):\n",
    "        \"\"\"\n",
    "        callback function for streaming audio and performing inference\n",
    "        \"\"\"\n",
    "        signal = np.frombuffer(in_data, dtype=np.int16)\n",
    "        vad_result = vad.transcribe(signal) \n",
    "        mbn_result = mbn.transcribe(signal) \n",
    "        \n",
    "        if len(vad_result):\n",
    "            # if speech prob is higher than threshold, we decide it contains speech utterance \n",
    "            # and activate MatchBoxNet \n",
    "            if vad_result[3] >= vad_threshold: \n",
    "                print(mbn_result) # print mbn result when speech present\n",
    "            else:\n",
    "                print(\"no-speech\")\n",
    "        return (in_data, pa.paContinue)\n",
    "\n",
    "    # streaming\n",
    "    stream = p.open(format=pa.paInt16,\n",
    "                    channels=CHANNELS,\n",
    "                    rate=SAMPLE_RATE,\n",
    "                    input=True,\n",
    "                    input_device_index=dev_idx,\n",
    "                    stream_callback=callback,\n",
    "                    frames_per_buffer=CHUNK_SIZE)\n",
    "\n",
    "    \n",
    "    print('Listening...')\n",
    "    stream.start_stream()\n",
    "    \n",
    "    # Interrupt the kernel and then keep speaking for a few more words to exit the PyAudio loop.\n",
    "    try:\n",
    "        while stream.is_active():\n",
    "            time.sleep(0.1)\n",
    "    finally:        \n",
    "        stream.stop_stream()\n",
    "        stream.close()\n",
    "        p.terminate()\n",
    "        print()\n",
    "        print(\"PyAudio stopped\")\n",
    "    \n",
    "else:\n",
    "    print('ERROR: No audio input device found.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## ONNX Deployment\n",
    "You can also export the model to an ONNX file and deploy it with the TensorRT or Microsoft ONNX Runtime inference engines. If you don't have ONNX Runtime installed yet, please run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade onnxruntime # for gpu, use onnxruntime-gpu\n",
    "# !mkdir -p ort\n",
    "# %cd ort\n",
    "# !git clone --depth 1 --branch v1.8.0 https://github.com/microsoft/onnxruntime.git .\n",
    "# !./build.sh --skip_tests --config Release --build_shared_lib --parallel --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu --build_wheel\n",
    "# !pip install ./build/Linux/Release/dist/onnxruntime*.whl\n",
    "# %cd .."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then replace the `infer_signal` implementation with the following code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import onnxruntime\n",
    "mbn_model.export('mbn.onnx')\n",
    "ort_session = onnxruntime.InferenceSession('mbn.onnx')\n",
    "\n",
    "def to_numpy(tensor):\n",
    "    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
    "\n",
    "# Keep the same signature as the PyTorch version so that FrameASR can call it unchanged.\n",
    "# Only MatchboxNet is exported here; the VAD model would need its own ONNX session.\n",
    "def infer_signal(model, signal):\n",
    "    data_layer.set_signal(signal)\n",
    "    batch = next(iter(data_loader))\n",
    "    audio_signal, audio_signal_len = batch\n",
    "    audio_signal, audio_signal_len = audio_signal.to(model.device), audio_signal_len.to(model.device)\n",
    "    processed_signal, processed_signal_len = model.preprocessor(\n",
    "        input_signal=audio_signal, length=audio_signal_len,\n",
    "    )\n",
    "    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(processed_signal), }\n",
    "    ologits = ort_session.run(None, ort_inputs)\n",
    "    alogits = np.asarray(ologits)\n",
    "    logits = torch.from_numpy(alogits[0])\n",
    "    return logits"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}